package com.sifou.courses.limit;

import com.google.common.base.CaseFormat;
import com.google.common.util.concurrent.RateLimiter;
import lombok.extern.slf4j.Slf4j;
import org.redisson.RedissonObject;
import org.redisson.api.*;
import org.springframework.context.ApplicationListener;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.event.ContextRefreshedEvent;
import org.springframework.core.annotation.AnnotationUtils;

import javax.annotation.Resource;
import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Map;

/**
 * Registers a token bucket for every {@link RateLimit}-annotated bean method once the
 * application context has been refreshed.
 *
 * <p>Single-node limits are backed by a Guava {@link RateLimiter} stored in
 * {@link RateLimitMapping}; cluster limits are backed by a Redisson {@link RRateLimiter}
 * whose configuration lives in Redis.
 *
 * @author liuzhongxu
 * @date 2020/10/10
 */
@Slf4j
@Configuration
public class RateLimitConfig implements ApplicationListener<ContextRefreshedEvent> {

    /**
     * Lua script that drops the existing limiter state and writes a new configuration.
     * KEYS[1] is the limiter key, KEYS[2] and KEYS[3] are its "value" and "permits" keys;
     * ARGV[1..3] are rate, interval and rate type.
     */
    private static final String RESET_RATE_SCRIPT =
            "redis.call('del', KEYS[1]);"
                    + "redis.call('del', KEYS[2]);"
                    + "redis.call('del', KEYS[3]);"
                    + "redis.call('hsetnx', KEYS[1], 'rate', ARGV[1]);"
                    + "redis.call('hsetnx', KEYS[1], 'interval', ARGV[2]);"
                    + "return redis.call('hsetnx', KEYS[1], 'type', ARGV[3]);";

    @Resource
    private RateLimitMapping rateLimitMapping;

    @Resource
    private RedissonClient redissonClient;

    /**
     * Scans every bean in the refreshed context and initializes a limiter for each method
     * carrying a {@link RateLimit} annotation.
     *
     * <p>The limiter name is {@code <BeanNameInUpperCamel>.<methodName>}.
     *
     * @param event the context-refreshed event whose application context is scanned
     */
    @Override
    public void onApplicationEvent(ContextRefreshedEvent event) {
        // All beans in the context, keyed by bean name.
        Map<String, Object> beansOfType = event.getApplicationContext().getBeansOfType(Object.class);
        for (Map.Entry<String, Object> entry : beansOfType.entrySet()) {
            Object bean = entry.getValue();
            String key = CaseFormat.LOWER_CAMEL.to(CaseFormat.UPPER_CAMEL, entry.getKey());

            // Methods declared directly on the bean's runtime class.
            for (Method method : bean.getClass().getDeclaredMethods()) {
                RateLimit limit = AnnotationUtils.findAnnotation(method, RateLimit.class);
                if (limit != null) {
                    initRateLimit(limit, String.format("%s.%s", key, method.getName()));
                }
            }
        }
    }

    /**
     * Initializes the token bucket described by {@code limit}.
     *
     * @param limit the annotation carrying the limiter type and threshold
     * @param name  the limiter name, used as the map key or Redis key
     */
    private void initRateLimit(RateLimit limit, String name) {
        log.debug("Init RateLimit name={} limitType={} limitNum={}", name, limit.limitType(), limit.limitNum());
        if (RateLimit.LimitType.single.equals(limit.limitType())) {
            initSingleRateLimit(limit, name);
        } else if (RateLimit.LimitType.cluster.equals(limit.limitType())) {
            initClusterRateLimit(limit, name);
        }
    }

    /** Creates an in-process Guava limiter and registers it under {@code name}. */
    private void initSingleRateLimit(RateLimit limit, String name) {
        RateLimiter rateLimiter = RateLimiter.create(limit.limitNum());
        rateLimitMapping.getSingleLimiterMap().put(name, rateLimiter);
    }

    /**
     * Creates the Redis-backed limiter, or rewrites its configuration when a limiter with a
     * different rate is already stored under {@code name}.
     */
    private void initClusterRateLimit(RateLimit limit, String name) {
        RRateLimiter rateLimiter = redissonClient.getRateLimiter(name);
        // RateType.OVERALL    - one shared limit across all clients (whole cluster).
        // RateType.PER_CLIENT - each client (machine) is limited independently.
        if (rateLimiter.trySetRate(RateType.OVERALL, limit.limitNum(), 1, RateIntervalUnit.SECONDS)) {
            return;
        }

        // trySetRate returned false; only rewrite when the stored rate differs from the annotation.
        long configuredRate = rateLimiter.getConfig().getRate();
        if (configuredRate == limit.limitNum()) {
            return;
        }

        // Run synchronously so a failed update surfaces at startup instead of being silently
        // dropped, and so the new threshold is stored before this method returns.
        // NOTE(review): the key names presumably mirror the ones RRateLimiter keeps internally,
        // and 1000 is presumably the one-second interval in milliseconds - confirm against the
        // Redisson version in use.
        // NOTE(review): arguments go through the default RScript codec; verify they are stored
        // in the same form trySetRate writes.
        Boolean updated = redissonClient.getScript().eval(RScript.Mode.READ_WRITE,
                RESET_RATE_SCRIPT,
                RScript.ReturnType.BOOLEAN,
                Arrays.asList(name, RedissonObject.suffixName(name, "value"), RedissonObject.suffixName(name, "permits")),
                limit.limitNum(), 1000, RateType.OVERALL.ordinal());
        log.debug("Updated RateLimit name={} oldRate={} newRate={} result={}",
                name, configuredRate, limit.limitNum(), updated);
    }
}
